{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dfsDR_omdNea"
      },
      "source": [
        "# Gemma Basics (Hugging Face)\n",
        "This notebook demonstrates how to load, fine-tune, and deploy a Gemma model using Hugging Face.\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_1]Basics_with_HF.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FaqZItBdeokU"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
        "\n",
        "### Gemma setup\n",
        "\n",
        "**Before we dive into the tutorial, let's get you set up with Gemma:**\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "2. **Gemma Model Access:** Head over to the [Gemma model page](https://huggingface.co/google/gemma-2b) and accept the usage conditions.\n",
        "3. **Colab with Gemma Power:**  For this tutorial, you'll need a Colab runtime with enough resources to handle the Gemma 2B model. Choose an appropriate runtime when starting your Colab session.\n",
        "4. **Hugging Face Token:**  Generate a Hugging Face access token (preferably with `write` permission) by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where we'll set up environment variables in your Colab environment.**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CY2kGtsyYpHF"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A9sUQ4WrP-Yr"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
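    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you run this notebook outside Colab, `google.colab` cannot be imported. The cell below is a minimal sketch of one way to cover both cases. The helper name `resolve_hf_token` is hypothetical (it is not part of any library), and the sketch assumes that outside Colab the token is available in the `HF_TOKEN` environment variable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "\n",
        "def resolve_hf_token(env=None):\n",
        "    \"\"\"Hypothetical helper: read HF_TOKEN from Colab secrets or the environment.\"\"\"\n",
        "    env = os.environ if env is None else env\n",
        "    try:\n",
        "        # `google.colab` only exists inside a Colab runtime.\n",
        "        from google.colab import userdata\n",
        "    except ImportError:\n",
        "        token = env.get(\"HF_TOKEN\")\n",
        "    else:\n",
        "        token = userdata.get(\"HF_TOKEN\")\n",
        "    if not token:\n",
        "        raise RuntimeError(\"HF_TOKEN is not set; add it as a secret or an env var.\")\n",
        "    return token\n",
        "\n",
        "\n",
        "# Example usage (uncomment to apply):\n",
        "# os.environ[\"HF_TOKEN\"] = resolve_hf_token()"
      ]
    },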
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iwjo5_Uucxkw"
      },
      "source": [
        "### Install dependencies\n",
        "Run the cell below to install all the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r_nXPEsF7UWQ"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade -q transformers huggingface_hub peft \\\n",
        "  accelerate bitsandbytes datasets trl"
      ]
    },
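    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the install command upgrades to the latest releases, it can help to record which versions you actually got. The following is a small sketch using only the standard library; the helper name `installed_versions` is hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from importlib import metadata\n",
        "\n",
        "PACKAGES = [\n",
        "    \"transformers\",\n",
        "    \"huggingface_hub\",\n",
        "    \"peft\",\n",
        "    \"accelerate\",\n",
        "    \"bitsandbytes\",\n",
        "    \"datasets\",\n",
        "    \"trl\",\n",
        "]\n",
        "\n",
        "\n",
        "def installed_versions(packages):\n",
        "    \"\"\"Map each package name to its installed version, or None if missing.\"\"\"\n",
        "    versions = {}\n",
        "    for name in packages:\n",
        "        try:\n",
        "            versions[name] = metadata.version(name)\n",
        "        except metadata.PackageNotFoundError:\n",
        "            versions[name] = None\n",
        "    return versions\n",
        "\n",
        "\n",
        "for name, version in installed_versions(PACKAGES).items():\n",
        "    print(f\"{name}: {version or 'not installed'}\")"
      ]
    },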
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2_bahJBmwvSp"
      },
      "source": [
        "### Log into Hugging Face Hub\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GIFFCHi-wvSp"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
            "Token is valid (permission: write).\n",
            "Your token has been saved to /root/.cache/huggingface/token\n",
            "Login successful\n"
          ]
        }
      ],
      "source": [
        "from huggingface_hub import login\n",
        "\n",
        "login(os.environ[\"HF_TOKEN\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gFLddpGeaKh5"
      },
      "source": [
        "All set and ready to explore the possibilities with Gemma!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yXFZFUJHgTcU"
      },
      "source": [
        "## Instantiate the Gemma 2B model\n",
        "\n",
        "Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your own cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.\n",
        "\n",
        "\n",
        "Let's get started by loading the model from Hugging Face Hub."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jgl8ZjHpwvSq"
      },
      "source": [
        "### Loading the model from HF Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w_z4600bwvSq"
      },
      "outputs": [],
      "source": [
        "model_id = \"google/gemma-1.1-2b-it\"\n",
        "device = \"cuda\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "74tpQWWWwvSq"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8eabd3491fde4488ad214fda12e3d757",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "74178666370244a6abbfe39070ac3715",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "657f11310caf4fd9b66625ffb23e0e47",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "430c79d1285c4feb8ee49a04f9bd8935",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Let's load the tokenizer first\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UD-eXTxxwvSq"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7bfab9fcc0dd4657b140fec719335c54",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/618 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1bd98f8dcb3c4192956e91862475e625",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1512dae56a3f465887754363a28e76df",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "62df1c005c204779a65f6ecc22416704",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1330a7a3b681427fbcbed3a6ae0aed21",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "69df67841e694898a18e046402e0c7be",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "683d2f59556b4e56a29abd8f395ee9fe",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "\n",
        "# Let's quantize the model to 4 bits to reduce its memory footprint\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16\n",
        ")\n",
        "\n",
        "# Let's load the final model\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id, quantization_config=bnb_config, device_map={\"\": 0}\n",
        ")"
      ]
    },
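    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for how much quantization shrinks the weights, here is a rough back-of-the-envelope estimate. It assumes the downloaded checkpoint (about 4.95 GB plus 67.1 MB, per the download progress bars) stores 16 bits per parameter; that is an assumption, not something this notebook states. It also ignores quantization metadata and any layers that stay unquantized, so the real footprint will be somewhat higher. The helper name `weight_memory_gb` is hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def weight_memory_gb(num_params, bits_per_param):\n",
        "    \"\"\"Approximate memory for the weights alone, in gigabytes (10**9 bytes).\"\"\"\n",
        "    return num_params * bits_per_param / 8 / 1e9\n",
        "\n",
        "\n",
        "# Checkpoint size taken from the two safetensors shards downloaded above.\n",
        "checkpoint_bytes = 4.95e9 + 67.1e6\n",
        "# Assumption: the checkpoint uses 2 bytes (16 bits) per parameter.\n",
        "num_params = checkpoint_bytes / 2\n",
        "\n",
        "for bits in (16, 8, 4):\n",
        "    print(f\"{bits:>2}-bit weights: ~{weight_memory_gb(num_params, bits):.2f} GB\")"
      ]
    },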
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lyw7fwOGwvSq"
      },
      "source": [
        "### Trying it out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nrVBVTtlwvSq"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "My favourite color is blue. It represents calmness, trust, and serenity. It brings me a sense of peace and tranquility\n"
          ]
        }
      ],
      "source": [
        "prompt = \"My favourite color is\"\n",
        "inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
        "outputs = model.generate(inputs, max_new_tokens=20)\n",
        "text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "print(text)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zYT6m2LNvxdo"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What can you use an LLM for? Answer:\n",
            "\n",
            "**An LLM (Large Language Model) can be used for a wide range of tasks, including:**\n",
            "\n",
            "* **Information retrieval:** Providing summaries, answering questions, and providing factual information.\n",
            "* **Content creation:** Generating creative text formats, writing different kinds of content, and translating languages.\n",
            "* **Summarization:** Extracting key points from large amounts of text.\n",
            "* **Code generation:** Assisting developers in writing code and debugging errors.\n",
            "* **Customer service:** Providing personalized and contextual support to users.\n",
            "* **Education:** Providing personalized learning experiences and generating educational materials.\n",
            "* **Marketing:** Creating targeted marketing campaigns and analyzing customer data.\n",
            "* **Translation:** Translating documents and websites between multiple languages.\n",
            "* **Creative writing:** Generating original and imaginative content.\n"
          ]
        }
      ],
      "source": [
        "prompt = \"What can you use an LLM for? Answer:\"\n",
        "inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
        "outputs = model.generate(inputs, max_new_tokens=512)\n",
        "text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "print(text)"
      ]
    },
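    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`google/gemma-1.1-2b-it` is an instruction-tuned variant, and such models are usually prompted with Gemma's turn markers rather than raw text. The sketch below builds that prompt string by hand so the structure is visible; `build_gemma_prompt` is a hypothetical helper. In practice, `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` should give you the same structure (it typically also prepends `<bos>`, which the tokenizer otherwise adds when encoding)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def build_gemma_prompt(user_message):\n",
        "    \"\"\"Hypothetical helper: wrap a single user message in Gemma's turn markers.\"\"\"\n",
        "    return (\n",
        "        \"<start_of_turn>user\\n\"\n",
        "        f\"{user_message}<end_of_turn>\\n\"\n",
        "        \"<start_of_turn>model\\n\"\n",
        "    )\n",
        "\n",
        "\n",
        "chat_prompt = build_gemma_prompt(\"What can you use an LLM for?\")\n",
        "print(chat_prompt)\n",
        "\n",
        "# To generate from it, reuse the pattern from the cells above:\n",
        "# inputs = tokenizer.encode(chat_prompt, return_tensors=\"pt\").to(device)\n",
        "# outputs = model.generate(inputs, max_new_tokens=512)"
      ]
    },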
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QlFBTx33wvSq"
      },
      "source": [
        "## Fine-tuning the model with LoRA\n",
        "\n",
        "This section of the guide focuses on training your Large Language Model (LLM) to generate famous quotations. Here, we will explore the process of fine-tuning your model to enable it to produce outputs similar to renowned authors, philosophers, and leaders."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cuqqe59fwvSq"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'Quote: Imagination is more than just a spark of genius; it is the fertile ground from which great art, science, and'"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Let's try it out before the fine-tuning\n",
        "text = \"Quote: Imagination is more\"\n",
        "inputs = tokenizer(text, return_tensors=\"pt\").to(device)\n",
        "outputs = model.generate(**inputs, max_new_tokens=20)\n",
        "tokenizer.decode(outputs[0], skip_special_tokens=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8_iH8JINwvSr"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "72c981db74b745afb63cb6586d38b417",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading readme:   0%|          | 0.00/5.55k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9e0add45514f4b7a82153ecce74d9c27",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading data:   0%|          | 0.00/647k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f49f9ec5296940b1867d0efc2051c46e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/2508 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Example item: {'quote': '“Be yourself; everyone else is already taken.”', 'author': 'Oscar Wilde', 'tags': ['be-yourself', 'gilbert-perreira', 'honesty', 'inspirational', 'misattributed-oscar-wilde', 'quote-investigator']}\n"
          ]
        }
      ],
      "source": [
        "# Loading and processing the dataset\n",
        "from datasets import load_dataset\n",
        "\n",
        "data = load_dataset(\"Abirate/english_quotes\")\n",
        "print(\"Example item:\", data[\"train\"][0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "65Cmr3fBw9bN"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f29b1baf6618448a8c0e74e5a7c73115",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/2508 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Let's tokenize the quotes\n",
        "data = data.map(lambda samples: tokenizer(samples[\"quote\"]), batched=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C8RKs_oZwvSr"
      },
      "outputs": [],
      "source": [
        "from peft import LoraConfig\n",
        "\n",
        "# Define tuning parameters\n",
        "lora_config = LoraConfig(\n",
        "    r=8,\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    target_modules=[\n",
        "        \"q_proj\",\n",
        "        \"o_proj\",\n",
        "        \"k_proj\",\n",
        "        \"v_proj\",\n",
        "        \"gate_proj\",\n",
        "        \"up_proj\",\n",
        "        \"down_proj\",\n",
        "    ],\n",
        ")"
      ]
    },
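    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "LoRA keeps each targeted weight matrix frozen and learns a low-rank update built from two small matrices, so the rank `r` directly controls how many parameters are trained. The arithmetic below is a sketch for a single linear layer; the 2048 x 2048 shape is illustrative only and is not read from the Gemma configuration, and `lora_param_count` is a hypothetical helper."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def lora_param_count(in_features, out_features, r):\n",
        "    \"\"\"Trainable parameters LoRA adds to one linear layer.\n",
        "\n",
        "    The update is B @ A, where A has shape (r, in_features) and\n",
        "    B has shape (out_features, r).\n",
        "    \"\"\"\n",
        "    return r * (in_features + out_features)\n",
        "\n",
        "\n",
        "in_features, out_features = 2048, 2048  # illustrative layer shape\n",
        "frozen = in_features * out_features\n",
        "added = lora_param_count(in_features, out_features, r=8)\n",
        "\n",
        "print(f\"Frozen weights in the layer: {frozen:,}\")\n",
        "print(f\"Trainable LoRA weights (r=8): {added:,} ({added / frozen:.2%} of the layer)\")"
      ]
    },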
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m39PDWM3wvSr"
      },
      "outputs": [],
      "source": [
        "def formatting_func(example):\n",
        "    text = f\"Quote: {example['quote'][0]}\\nAuthor: {example['author'][0]}<eos>\"\n",
        "    return [text]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oysHd0jXwvSr"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:246: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:318: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
            "  warnings.warn(\n",
            "max_steps is given, it will override any value given in num_train_epochs\n"
          ]
        }
      ],
      "source": [
        "import transformers\n",
        "from trl import SFTTrainer\n",
        "\n",
        "# Create a trainer object that takes care of the training process\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    train_dataset=data[\"train\"],\n",
        "    args=transformers.TrainingArguments(\n",
        "        per_device_train_batch_size=1,\n",
        "        gradient_accumulation_steps=4,\n",
        "        warmup_steps=2,\n",
        "        max_steps=10,\n",
        "        learning_rate=2e-4,\n",
        "        fp16=True,\n",
        "        logging_steps=1,\n",
        "        output_dir=\"outputs\",\n",
        "        optim=\"paged_adamw_8bit\",\n",
        "    ),\n",
        "    peft_config=lora_config,\n",
        "    formatting_func=formatting_func,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yCeOevVHsJGX"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [10/10 00:11, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>8.429300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>6.572500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>7.176600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>7.020800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>8.272000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>7.685400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>5.656700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>5.468200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>6.232400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>7.083500</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=10, training_loss=6.95973539352417, metrics={'train_runtime': 12.9991, 'train_samples_per_second': 3.077, 'train_steps_per_second': 0.769, 'total_flos': 16634596884480.0, 'train_loss': 6.95973539352417, 'epoch': 0.01594896331738437})"
            ]
          },
          "execution_count": 16,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Let's run the fine-tuning\n",
        "trainer.train()"
      ]
    },
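    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `epoch` value reported by `trainer.train()` follows from the batch settings passed to `TrainingArguments`. The short calculation below reproduces it, assuming a single GPU (`num_devices = 1` is an assumption) and the 2,508 training examples reported when the dataset was loaded. It also shows that this run only touches a small fraction of the dataset, so treat it as a smoke test rather than a full fine-tune."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "per_device_train_batch_size = 1\n",
        "gradient_accumulation_steps = 4\n",
        "max_steps = 10\n",
        "num_train_examples = 2508  # size of the train split\n",
        "num_devices = 1  # assumption: a single GPU\n",
        "\n",
        "# Examples consumed per optimizer step.\n",
        "effective_batch_size = (\n",
        "    per_device_train_batch_size * gradient_accumulation_steps * num_devices\n",
        ")\n",
        "examples_seen = effective_batch_size * max_steps\n",
        "epoch_fraction = examples_seen / num_train_examples\n",
        "\n",
        "print(f\"Effective batch size: {effective_batch_size}\")\n",
        "print(f\"Examples seen: {examples_seen}\")\n",
        "print(f\"Fraction of one epoch: {epoch_fraction:.4f}\")"
      ]
    },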
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y1hPDZgZwvSr"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Quote: Imagination is the faculty of the mind to create new things.\" - Albert Einstein\n",
            "\n",
            "**Answer:**\n",
            "\n",
            "Imagination\n"
          ]
        }
      ],
      "source": [
        "# Test the model after fine-tuning\n",
        "text = \"Quote: Imagination is\"\n",
        "inputs = tokenizer(text, return_tensors=\"pt\").to(device)\n",
        "\n",
        "outputs = model.generate(**inputs, max_new_tokens=20)\n",
        "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p3jL-Z8CtqgP"
      },
      "source": [
        "## Push the model to your Hugging Face Hub\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aM84Ti3r02Tz"
      },
      "source": [
        "Hugging Face allow to you easily store trained models in their hub."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HIDWBva0_SX4"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ed94520b9077442a8d3899d70bf31b08",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/5.18k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b5b398b7208d43eba552b34efbae341c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "CommitInfo(commit_url='https://huggingface.co/f33ac/my-gemma-2-finetuned-model/commit/c837075477e241519df9aaf42e6a032b1d2e6df7', commit_message='Upload GemmaForCausalLM', commit_description='', oid='c837075477e241519df9aaf42e6a032b1d2e6df7', pr_url=None, pr_revision=None, pr_num=None)"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Note: The token needs to have \"write\" permission\n",
        "#       You can check it here:\n",
        "#       https://huggingface.co/settings/tokens\n",
        "model.push_to_hub(\"my-gemma-2-finetuned-model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5e-PkWR6wvSr"
      },
      "source": [
        "## Serve your model using Text Generation Inference (TGI)\n",
        "\n",
        "Text Generation Inference (TGI) is a toolkit that simplifies deploying and serving large language models (LLMs) like Gemma. It optimizes models for text generation through techniques like tensor parallelism, which distributes the workload across multiple GPUs for faster processing, and through code optimized specifically for text generation. TGI also offers features that make it suitable for production environments, such as distributed tracing for monitoring model performance, Prometheus metrics for detailed data collection, and watermarking of model outputs. You can read more about TGI by referring to [the official documentation](https://huggingface.co/docs/text-generation-inference/en/index)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bmx6iT6xp0RI"
      },
      "source": [
        "To deploy your model with TGI you can either:\n",
        "\n",
        "1. **Deploy it locally (requires Docker):** Uncomment the code cells below to run the model on your local machine. This approach requires Docker to be installed and GPU attached.\n",
        "\n",
        "2. **Deploy it on Google Cloud Platform using GKE:** Follow this guide [Serve Gemma open models using GPUs on GKE with Hugging Face TGI](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-gpu-tgi) to deploy your model on Google Cloud's CKE service. This option leverages GPUs for high-performance inference.\n",
        "\n",
        "Both deployment methods will provide you with an HTTP endpoint for sending requests and receiving text generation responses from your model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0wEjhtJawvSr"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# ID of the model on the Hugging Face Hub (you can use your own fine-tuned\n",
        "# model from the previous step).\n",
        "tgi_model = \"google/gemma-1.1-2b-it\"\n",
        "# Directory shared with the Docker container to avoid downloading weights\n",
        "# on every run.\n",
        "tgi_volume = os.path.join(os.getcwd(), \"data\")\n",
        "\n",
        "# IPython substitutes {name} in `!` commands with the Python variable's value.\n",
        "# !docker run --gpus all --shm-size 1g -p 8080:80 \\\n",
        "#     -v {tgi_volume}:/data ghcr.io/huggingface/text-generation-inference:2.0.3 \\\n",
        "#     --model-id {tgi_model}"
      ]
    }
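    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Once a TGI server is running, you can query it over HTTP. The cell below is a minimal sketch using only the Python standard library. It assumes the server is reachable at `http://localhost:8080` (matching the `-p 8080:80` port mapping in the Docker command) and uses TGI's `/generate` route; the helper name `build_generate_request` is hypothetical. The request is only built here; uncomment the last lines to send it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import urllib.request\n",
        "\n",
        "TGI_URL = \"http://localhost:8080/generate\"  # assumption: local server on port 8080\n",
        "\n",
        "\n",
        "def build_generate_request(prompt, max_new_tokens=20, url=TGI_URL):\n",
        "    \"\"\"Hypothetical helper: build a POST request for TGI's /generate route.\"\"\"\n",
        "    payload = {\"inputs\": prompt, \"parameters\": {\"max_new_tokens\": max_new_tokens}}\n",
        "    return urllib.request.Request(\n",
        "        url,\n",
        "        data=json.dumps(payload).encode(\"utf-8\"),\n",
        "        headers={\"Content-Type\": \"application/json\"},\n",
        "        method=\"POST\",\n",
        "    )\n",
        "\n",
        "\n",
        "request = build_generate_request(\"Quote: Imagination is\")\n",
        "\n",
        "# Uncomment once the server is running:\n",
        "# with urllib.request.urlopen(request) as response:\n",
        "#     print(json.loads(response.read())[\"generated_text\"])"
      ]
    }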
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_1]Basics_with_HF.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
